use super::{Notify, Token, TokenLink};
use core::cmp::min;
use core::sync::atomic::{
    AtomicU16, AtomicU64,
    Ordering::{self, SeqCst},
};

/// Shared state word plus one `TokenLink` per side.
///
/// `status` packs five fields into a single `u64` so that they can be read and
/// replaced together with one atomic operation. The layout is the one produced
/// by `park` and decoded by `unpark`:
///
/// | bits   | field                                                  |
/// |--------|--------------------------------------------------------|
/// | 63..48 | `ridx` (the index advanced by `update_ridx`)           |
/// | 47..32 | `widx` (the index advanced by `update_widx`)           |
/// | 31..16 | element count, `count(ridx, widx)`                     |
/// | 15..8  | state: `EMPTY`, `PARTIAL` or `FULL`                    |
/// | 7..0   | flag set to 1 on a state change, cleared by `notify`   |
pub(crate) struct AtomicStatus {
    /// Packed indices, count, state and notification flag (layout above).
    status: AtomicU64,
    /// Link through which `update_ridx_with` funnels read-index updates.
    rd: TokenLink,
    /// Link through which `update_widx_with` funnels write-index updates.
    wr: TokenLink,
}

impl AtomicStatus {
    /// Builds a status whose packed word is all zeroes (both indices 0,
    /// count 0, state `EMPTY`, flag clear) together with two fresh token links.
    pub(crate) const fn new() -> Self {
        let status = AtomicU64::new(0);
        Self {
            status,
            rd: TokenLink::new(),
            wr: TokenLink::new(),
        }
    }

    /// Reads the packed word once with ordering `ord` and returns
    /// `(ridx, widx, count)`; the state and flag bytes are discarded.
    pub(crate) fn load(&self, ord: Ordering) -> (u16, u16, u16) {
        let packed = self.status.load(ord);
        let ridx = (packed >> 48) as u16;
        let widx = (packed >> 32) as u16;
        let cnt = (packed >> 16) as u16;
        (ridx, widx, cnt)
    }

    /// Returns the read index currently stored in the packed word (`SeqCst` load).
    fn ridx(&self) -> u16 {
        let (ridx, _, _) = self.load(SeqCst);
        ridx
    }

    /// Returns the write index currently stored in the packed word (`SeqCst` load).
    fn widx(&self) -> u16 {
        let (_, widx, _) = self.load(SeqCst);
        widx
    }

    /// Advances the read index from `last` to `ridx`, routing the update
    /// through the `rd` token link.
    ///
    /// The request is first registered with `self.rd.lock(token, last, ridx)`.
    /// If that returns a null pointer the call ends immediately.
    /// NOTE(review): presumably a null result means another caller currently
    /// owns the link and will apply this update on our behalf; confirm
    /// against `TokenLink::lock`.
    ///
    /// Otherwise this caller takes the front entry of the sorted token list,
    /// applies it with `update_ridx`, and calls `self.rd.unlock(..)`, repeating
    /// for as long as `unlock` hands back a non-null list.
    ///
    /// * `last`, `ridx` - the index range being published (`last` is the value
    ///   the read index must have before `ridx` can be stored).
    /// * `max` - capacity, forwarded to `update_ridx` to classify the state.
    /// * `n` - receiver of state-change notifications.
    /// * `token` - caller-provided token handed to the link.
    ///
    /// # Panics
    ///
    /// Panics if the stored read index differs from `last` without being
    /// `less` than it, and propagates the assertions in `update_ridx`.
    pub(crate) fn update_ridx_with<N>(&self, last: u16, ridx: u16, max: u16, n: &N, token: &Token)
    where
        N: Notify,
    {
        let token = self.rd.lock(token, last, ridx);
        if token.is_null() {
            return;
        }
        // NOTE(review): the pointer is only null-checked before being
        // dereferenced; its validity rests on `TokenLink`'s contract, which is
        // not visible in this file.
        let mut sorted = unsafe { &*token }.sort();
        let or = self.ridx();
        if or != last {
            // The stored read index has not reached `last` yet, so the front
            // entry cannot be applied now. Hand the list back together with
            // the current index; a non-null result is a new list to process.
            assert!(less(or, last));
            let token = self.rd.unlock(sorted, or);
            if token.is_null() {
                return;
            }
            sorted = unsafe { &*token }.sort();
        }

        loop {
            // Shadow the parameters with the front entry's own range.
            let (last, ridx) = sorted.last_and_idx();
            // Detach the front entry before publishing it; `next` is whatever
            // `unlock` on the sorted list yields for the remainder.
            let next = sorted.unlock();
            self.update_ridx(last, ridx, max, n);
            // Report the newly stored index to the link. Null ends the call;
            // anything else is more work for this caller.
            let token = self.rd.unlock(next, ridx);
            if token.is_null() {
                return;
            }
            sorted = unsafe { &*token }.sort();
        }
    }

    /// Stores `ridx` as the new read index, recomputing the element count and
    /// the state in the same compare-exchange.
    ///
    /// * `last` - the read index the word is expected to hold right now.
    /// * `ridx` - the read index to store.
    /// * `max` - capacity used by `status` to classify the new count.
    /// * `n` - receiver of state-change notifications.
    ///
    /// When the state changes, the flag byte is set to 1. Only the caller that
    /// changes the state while the flag was still 0 runs `notify`; `notify`
    /// is what clears the flag again.
    ///
    /// # Panics
    ///
    /// Panics if the stored read index is not `last`, or if the recomputed
    /// count exceeds `max` (assertion inside `status`).
    pub(crate) fn update_ridx<N>(&self, last: u16, ridx: u16, max: u16, n: &N)
    where
        N: Notify,
    {
        let mut current = self.status.load(SeqCst);
        loop {
            let (stored_ridx, widx, _, prev_state, pending) = unpark(current);
            assert_eq!(last, stored_ridx);

            let cnt = count(ridx, widx);
            let state = status(cnt, max);
            let changed = state != prev_state;
            // Keep the flag as it is unless the state moved, in which case raise it.
            let flag = if changed { 1 } else { pending };
            let desired = park(ridx, widx, cnt, state, flag);

            if let Err(actual) = self.status.compare_exchange(current, desired, SeqCst, SeqCst) {
                // Lost a race (e.g. the write side moved); retry on the fresh word.
                current = actual;
                continue;
            }

            if changed && pending == 0 {
                self.notify(prev_state, state, n);
            }
            return;
        }
    }

    /// Advances the write index from `last` to `widx`, routing the update
    /// through the `wr` token link.
    ///
    /// The request is first registered with `self.wr.lock(token, last, widx)`.
    /// If that returns a null pointer the call ends immediately.
    /// NOTE(review): presumably a null result means another caller currently
    /// owns the link and will apply this update on our behalf; confirm
    /// against `TokenLink::lock`.
    ///
    /// Otherwise this caller takes the front entry of the sorted token list,
    /// applies it with `update_widx`, and calls `self.wr.unlock(..)`, repeating
    /// for as long as `unlock` hands back a non-null list.
    ///
    /// * `last`, `widx` - the index range being published (`last` is the value
    ///   the write index must have before `widx` can be stored).
    /// * `max` - capacity, forwarded to `update_widx` to classify the state.
    /// * `n` - receiver of state-change notifications.
    /// * `token` - caller-provided token handed to the link.
    ///
    /// # Panics
    ///
    /// Panics if the stored write index differs from `last` without being
    /// `less` than it, and propagates the assertions in `update_widx`.
    pub(crate) fn update_widx_with<N>(&self, last: u16, widx: u16, max: u16, n: &N, token: &Token)
    where
        N: Notify,
    {
        let token = self.wr.lock(token, last, widx);
        if token.is_null() {
            return;
        }
        // NOTE(review): the pointer is only null-checked before being
        // dereferenced; its validity rests on `TokenLink`'s contract, which is
        // not visible in this file.
        let mut sorted = unsafe { &*token }.sort();
        let ow = self.widx();
        if ow != last {
            // The stored write index has not reached `last` yet, so the front
            // entry cannot be applied now. Hand the list back together with
            // the current index; a non-null result is a new list to process.
            assert!(less(ow, last));
            let token = self.wr.unlock(sorted, ow);
            if token.is_null() {
                return;
            }
            sorted = unsafe { &*token }.sort();
        }

        loop {
            // Shadow the parameters with the front entry's own range.
            let (last, widx) = sorted.last_and_idx();
            // Detach the front entry before publishing it; `next` is whatever
            // `unlock` on the sorted list yields for the remainder.
            let next = sorted.unlock();
            self.update_widx(last, widx, max, n);
            // Report the newly stored index to the link. Null ends the call;
            // anything else is more work for this caller.
            let token = self.wr.unlock(next, widx);
            if token.is_null() {
                return;
            }
            sorted = unsafe { &*token }.sort();
        }
    }

    /// Stores `widx` as the new write index, recomputing the element count and
    /// the state in the same compare-exchange.
    ///
    /// * `last` - the write index the word is expected to hold right now.
    /// * `widx` - the write index to store.
    /// * `max` - capacity used by `status` to classify the new count.
    /// * `n` - receiver of state-change notifications.
    ///
    /// When the state changes, the flag byte is set to 1. Only the caller that
    /// changes the state while the flag was still 0 runs `notify`; `notify`
    /// is what clears the flag again.
    ///
    /// # Panics
    ///
    /// Panics if the stored write index is not `last`, or if the recomputed
    /// count exceeds `max` (assertion inside `status`).
    pub(crate) fn update_widx<N>(&self, last: u16, widx: u16, max: u16, n: &N)
    where
        N: Notify,
    {
        let mut current = self.status.load(SeqCst);
        loop {
            let (ridx, stored_widx, _, prev_state, pending) = unpark(current);
            assert_eq!(last, stored_widx);

            let cnt = count(ridx, widx);
            let state = status(cnt, max);
            let changed = state != prev_state;
            // Keep the flag as it is unless the state moved, in which case raise it.
            let flag = if changed { 1 } else { pending };
            let desired = park(ridx, widx, cnt, state, flag);

            if let Err(actual) = self.status.compare_exchange(current, desired, SeqCst, SeqCst) {
                // Lost a race (e.g. the read side moved); retry on the fresh word.
                current = actual;
                continue;
            }

            if changed && pending == 0 {
                self.notify(prev_state, state, n);
            }
            return;
        }
    }

    /// Reports the `os` -> `s` state transition to `n`, then clears the flag
    /// byte of the packed word.
    ///
    /// While clearing, the state stored in the word is compared with the last
    /// state that was reported. If they differ, the further transition
    /// `s -> ns` is reported as well and `s` is updated, so state changes made
    /// by other callers while the flag was raised are still delivered.
    ///
    /// * `os` - state before the transition that triggered this call.
    /// * `s` - state after it; reused as "last state reported" afterwards.
    /// * `n` - receiver of the notifications.
    fn notify<N>(&self, os: u8, mut s: u8, n: &N)
    where
        N: Notify,
    {
        n.notify(os, s);
        // The closure always returns `Some`, so `fetch_update` cannot fail and
        // its result carries no information. `fetch_update` may run the
        // closure more than once when its compare-exchange loses a race; each
        // run sees a fresher word, and `s` remembers what was already reported
        // so the same transition is not reported twice.
        let _ = self.status.fetch_update(SeqCst, SeqCst, |status| {
            let (r, w, cnt, ns, _) = unpark(status);
            if ns != s {
                n.notify(s, ns);
                s = ns;
            }
            // Same indices, count and state; only the flag byte goes back to 0.
            Some(park(r, w, cnt, ns, 0))
        });
    }

    pub(crate) fn fetch_ridx(&self, cnt: u16) -> (u16, u16) {
        let (r, _, n) = self.load(SeqCst);
        let cnt = min(cnt, n);
        (r, cnt)
    }

    pub(crate) fn fetch_widx(&self, cnt: u16, max: u16) -> (u16, u16) {
        let (_, w, n) = self.load(SeqCst);
        let cnt = min(cnt, max - n);
        (w, cnt)
    }

    /// Claims up to `cnt` slots for reading by advancing the shared cursor
    /// `idx` with a compare-exchange.
    ///
    /// Returns `(start, claimed)`, where `start` is the cursor value before
    /// the claim and `claimed` is `cnt` clamped to `count(start, widx)`.
    /// Returns `(0, 0)` as soon as the stored element count is observed to be
    /// zero.
    ///
    /// While the count is non-zero but the cursor is not `less` than the
    /// stored write index, the function spins, re-reading the status word on
    /// every iteration. NOTE(review): presumably this window exists because
    /// other callers have claimed the remaining slots via `idx` but have not
    /// yet published their read-index update; confirm against the callers.
    ///
    /// # Panics
    ///
    /// Panics if the clamped claim is zero, which happens when `cnt == 0` is
    /// passed while slots are claimable.
    pub(crate) fn fetch_ridx_with(&self, idx: &AtomicU16, cnt: u16) -> (u16, u16) {
        let mut r = idx.load(SeqCst);
        loop {
            let (_, w, n) = self.load(SeqCst);
            if n == 0 {
                return (0, 0);
            }
            if !less(r, w) {
                // Nothing claimable right now although the count is non-zero:
                // tell the CPU this is a busy-wait before polling again.
                core::hint::spin_loop();
                continue;
            }
            let cnt = min(cnt, count(r, w));
            assert!(cnt > 0);
            match idx.compare_exchange(r, r.wrapping_add(cnt), SeqCst, SeqCst) {
                Ok(_) => return (r, cnt),
                // Another caller moved the cursor; retry from its value.
                Err(val) => r = val,
            }
        }
    }

    /// Claims up to `cnt` slots for writing by advancing the shared cursor
    /// `idx` with a compare-exchange.
    ///
    /// The claimable range ends at `ridx + max` (wrapping). Returns
    /// `(start, claimed)`, where `start` is the cursor value before the claim
    /// and `claimed` is `cnt` clamped to `count(start, ridx + max)`. Returns
    /// `(0, 0)` as soon as the stored element count is observed to be at
    /// least `max`.
    ///
    /// While the count is below `max` but the cursor is not `less` than
    /// `ridx + max`, the function spins, re-reading the status word on every
    /// iteration. NOTE(review): presumably this window exists because other
    /// callers have claimed the remaining slots via `idx` but have not yet
    /// published their write-index update; confirm against the callers.
    ///
    /// # Panics
    ///
    /// Panics if the clamped claim is zero, which happens when `cnt == 0` is
    /// passed while slots are claimable.
    pub(crate) fn fetch_widx_with(&self, idx: &AtomicU16, cnt: u16, max: u16) -> (u16, u16) {
        let mut w = idx.load(SeqCst);
        loop {
            let (r, _, n) = self.load(SeqCst);
            if n >= max {
                return (0, 0);
            }
            let end = r.wrapping_add(max);
            if !less(w, end) {
                // Nothing claimable right now although the count is below
                // `max`: tell the CPU this is a busy-wait before polling again.
                core::hint::spin_loop();
                continue;
            }
            let cnt = min(cnt, count(w, end));
            assert!(cnt > 0);
            match idx.compare_exchange(w, w.wrapping_add(cnt), SeqCst, SeqCst) {
                Ok(_) => return (w, cnt),
                // Another caller moved the cursor; retry from its value.
                Err(val) => w = val,
            }
        }
    }
}

/// Distance between two wrapping 16-bit indices.
///
/// Returns `w - r` (wrapping) when that value is below `u16::MAX >> 1`
/// (32767); otherwise returns the distance measured the other way round,
/// `r - w` (wrapping).
pub(crate) fn count(r: u16, w: u16) -> u16 {
    const HALF: u16 = u16::MAX >> 1;
    match w.wrapping_sub(r) {
        forward if forward < HALF => forward,
        _ => r.wrapping_sub(w),
    }
}

/// Wrapping "strictly before" comparison of two 16-bit indices.
///
/// True when `rhs - lhs` (wrapping) is non-zero and below `u16::MAX >> 1`
/// (32767), i.e. `rhs` lies a short way ahead of `lhs`.
pub(crate) fn less(lhs: u16, rhs: u16) -> bool {
    const HALF: u16 = u16::MAX >> 1;
    let ahead = rhs.wrapping_sub(lhs);
    ahead != 0 && ahead < HALF
}

/// Splits a packed status word into `(ridx, widx, count, state, flag)`.
///
/// Inverse of `park`: the fields are read from the most significant bits
/// downwards (16, 16, 16, 8 and 8 bits wide).
fn unpark(status: u64) -> (u16, u16, u16, u8, u8) {
    let r = (status >> 48) as u16;
    let w = (status >> 32) as u16;
    let cnt = (status >> 16) as u16;
    let s = (status >> 8) as u8;
    let n = status as u8;
    (r, w, cnt, s, n)
}

/// Packs `(ridx, widx, count, state, flag)` into one status word.
///
/// Layout from the most significant bit: `r` (16 bits), `w` (16 bits),
/// `cnt` (16 bits), `s` (8 bits), `n` (8 bits).
fn park(r: u16, w: u16, cnt: u16, s: u8, n: u8) -> u64 {
    let indices = (u64::from(r) << 48) | (u64::from(w) << 32);
    let fill = u64::from(cnt) << 16;
    let flags = (u64::from(s) << 8) | u64::from(n);
    indices | fill | flags
}

/// Classifies an element count against the capacity `max`.
///
/// Returns `EMPTY` for zero, `PARTIAL` for `0 < cnt < max` and `FULL` for
/// `cnt == max`.
///
/// # Panics
///
/// Panics if `cnt` is non-zero and greater than `max`.
fn status(cnt: u16, max: u16) -> u8 {
    if cnt == 0 {
        return EMPTY;
    }
    if cnt < max {
        return PARTIAL;
    }
    assert_eq!(cnt, max);
    FULL
}

/// State byte produced by `status` when the element count is zero.
pub(crate) const EMPTY: u8 = 0;
/// State byte produced by `status` when the count is above zero and below `max`.
pub(crate) const PARTIAL: u8 = 1;
/// State byte produced by `status` when the count equals `max`.
pub(crate) const FULL: u8 = 2;
